Explainable AI - Model Explanation

Model & Data

The model we decided to explain is a Random Forest classifier. A Random Forest is an ensemble method built from a multitude of decision trees. Even though a single decision tree is a white-box model, a Random Forest becomes hard to visualize and interpret once it reaches a certain level of complexity. Random Forests are widely used and belong to the elementary machine learning models, which is why we decided to focus our efforts on finding methods that help make this model class more interpretable.

The goal of our model is to predict whether a patient is at a high risk of having a heart attack in the future. We trained our model on a simple heart attack dataset which contains patients’ medical information. We chose this dataset because we believe medicine is one of the areas where the interpretability of a model is crucial. We downloaded the dataset from Kaggle (https://www.kaggle.com/rashikrahmanpritom/heart-attack-analysis-prediction-dataset), uploaded by the user Rashik Rahman. During training we were not aiming for a perfect model; on the contrary, we intentionally left redundant features in the dataset to show that our methods for explaining the model can point out their redundancy.

Detailed description of the features in the dataset:

Model Training
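
The training step might look roughly like the following sketch. Since the Kaggle CSV is not loaded here, a synthetic stand-in of the same shape (303 rows, 13 features) replaces the real data; the split ratio and hyperparameters are assumptions, not the notebook's exact values.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the Kaggle heart-attack dataset (303 rows, 13 features).
X, y = make_classification(n_samples=303, n_features=13, n_informative=8,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42)

# The Random Forest classifier that all later explanation methods inspect.
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print(f"test accuracy: {model.score(X_test, y_test):.2f}")
```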

Global Feature Importance

The plot above shows the overall importance of each feature in the model. It is sorted by the mean permutation importance; additionally, the impurity-based feature importance of the Random Forest model is shown as a rescaled bar chart.

These are two ways to visualize which features the model deems important and which it does not. Permutation importance can rate a feature as unimportant, or less important than it actually is, when the feature is correlated with others. Therefore, correlations need to be checked before concluding that a feature is really unimportant. Impurity-based feature importance, in turn, is biased towards features with a large numerical range compared to features with only a few possible values.

Due to the weaknesses described above, this graph cannot be used to make any definitive statement about the model on its own. However, it gives a feeling for what the model deems important, and when applied to a model whose features are uncorrelated, it would allow a statement about the model (at this point we have no information about the correlations between features).
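
Both importance scores can be computed with a few lines of scikit-learn; the sketch below uses synthetic data in place of the heart dataset, and `n_repeats=10` is an assumption rather than the notebook's setting.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

X, y = make_classification(n_samples=303, n_features=13, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X, y)

# Permutation importance: the drop in score when one feature is shuffled.
perm = permutation_importance(model, X, y, n_repeats=10, random_state=42)

# Impurity-based importance comes for free with the fitted forest.
impurity = model.feature_importances_

order = np.argsort(perm.importances_mean)[::-1]
for i in order[:5]:
    print(f"feature {i}: perm={perm.importances_mean[i]:.3f} "
          f"impurity={impurity[i]:.3f}")
```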

There do not seem to be any strong correlations between features; therefore, the permutation importance should be reliable.
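
A quick way to run such a correlation check is to compute the pairwise correlation matrix and flag any off-diagonal entry above a chosen threshold; the 0.8 cutoff and the synthetic data below are illustrative assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification

X, _ = make_classification(n_samples=303, n_features=13, random_state=42)
corr = np.corrcoef(X, rowvar=False)          # 13 x 13 correlation matrix

# Flag feature pairs whose absolute correlation exceeds the threshold.
off_diag = np.abs(corr - np.eye(corr.shape[0]))
pairs = np.argwhere(off_diag > 0.8)
print("strongly correlated pairs:", pairs[pairs[:, 0] < pairs[:, 1]])
```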

SHAP Values

For a deeper understanding of the model's decisions we can compute SHAP values. SHAP values quantify how strongly each feature contributed to a decision and in which direction.

Below is an example in which the model concluded with 70% certainty that the patient has a heart disease:

The most important features which contributed to the decision are:

If the decision, or more specifically the features the decision was based on, seems unreasonable, the decision has to be made manually. If the final decision matches the model's, one could try to figure out why the correlation exists: whether there is valid reasoning behind it, or whether it is just an error in the model, for example due to over-fitting.

Below are the SHAP values for the whole dataset that was used:

Now let's try to extract information about the model's decisions in a more global manner. The plot below shows the SHAP values grouped by feature instead of by sample:

Blue dots represent lower feature values, red dots higher ones. The position on the x-axis shows the impact of the feature on the decision in its respective context. The features are sorted by their overall influence on the model's decisions (with regard to this dataset).

The strongest indicator for the decision (on average) seems to be cp, the type of chest pain the patient has. cp can take four different values, three of which indicate heart disease; the last one (most likely no pain) indicates no heart disease. In a similar manner, for every feature low values indicate one tendency and high values the other; this is clearer for discrete features than for features with large ranges. Age seems to be an exception to some extent, because both low and high values can be found on the same side of the plot.

The Trust Machine (Experimental)

The method is intended to give non-expert model users confidence in the model's decision by showing that there are many similar cases which were classified in the same way, and correctly so.

Below, the SHAP values of the dataset on our model were taken and projected using t-SNE. The hypothesis is that samples which are not near the decision border and lie in dense clusters are probably safe decisions that can be trusted.
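
The projection step itself is a single scikit-learn call. In the sketch below the raw feature matrix stands in for the per-sample SHAP matrix (to keep the example dependency-free); in the notebook, the SHAP values would be passed to `fit_transform` instead.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE

X, y = make_classification(n_samples=200, n_features=13, random_state=42)
model = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)

# Stand-in for the per-sample SHAP matrix; the notebook projects SHAP values.
values = X

# t-SNE with default-ish hyperparameters, as described in the text.
proj = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(values)
pred = model.predict(X)                      # used to color the scatter plot
print("projection shape:", proj.shape, "predicted positives:", pred.sum())
```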

No hyperparameter tuning was done, since the defaults look OK and we do not have much expectation about how many clusters should be visible.

Next we look for clusters in the projection:

The red line is our guess for the decision border. The blue clusters seem trustworthy according to our hypothesis. The violet clusters might be real, but they might also be artefacts or too close to the decision border, since our model only has an accuracy of 88% and produces mainly false positives (unhealthy = positive; this is visible in the check below).

Lastly, we verify our hypothesis by checking the true labels:

The blue clusters turned out to be valid. However, the lower the accuracy of the model and the bigger the dataset, the more often a cluster identified as "probably safe" will contain wrongly classified data points. Therefore, one could assign each cluster a certainty value evaluated on a large dataset. This way the method can also be applied to problems where classes partly overlap.
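
Such per-cluster certainty values could be computed by clustering the 2-D projection and scoring each cluster by the fraction of its points the model classified correctly. This sketch fakes the projection with blobs and a ~90%-accurate "model"; all of that is illustrative.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Stand-in 2-D projection (in the notebook: the t-SNE output) with labels.
proj, blob = make_blobs(n_samples=300, centers=4, random_state=42)
true = blob % 2                               # collapse to binary healthy/ill
rng = np.random.default_rng(0)
pred = np.where(rng.random(300) < 0.9, true, 1 - true)  # ~90%-accurate model

clusters = KMeans(n_clusters=4, n_init=10, random_state=42).fit_predict(proj)

# Certainty of a cluster = fraction of its points the model got right.
for c in range(4):
    mask = clusters == c
    certainty = (pred[mask] == true[mask]).mean()
    print(f"cluster {c}: {mask.sum()} points, certainty {certainty:.2f}")
```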

LIME

The main idea of this method is to approximate a complicated black-box model by a simpler glass-box one. It is usually applied to problems with very many explanatory variables, where the simpler glass-box model is easier to interpret. Concretely, the glass-box model is trained on artificial data so that it approximates the predictions of the complicated black-box model on these data. We decided to use this method for our problem, as medical data are usually complicated and contain many variables that matter for explanations. We learned about this method in the book Interpretable Machine Learning (Molnar, 2021) and in Explanatory Model Analysis (Biecek and Burzykowski, 2020).

WHY, WHO, WHAT, WHEN, WHERE, HOW method:

Below we apply the LIME method and present some of the local explanations.

The three results shown above are examples of different probabilities and degrees of certainty of the model. In rows 7 and 8, our model was very sure that its prediction is correct, so it has nearly no problem saying whether the given person has a high risk of a heart attack or not. Most (in row 8, all) features have values that are typical for the predicted class. In row 4 the situation is quite different: there the model is only 55% sure of its decision, and some of the feature values also seem typical for ill persons. In the end, the difference here is not that big; it seems more like luck whether the model chooses correctly.

We also decided to check which features, across all test data, the method considers to influence the model the most.

Most Important Feature

We group the rows by the model's prediction and then try to answer the question: what was the most important feature when illness was predicted, and what when health was predicted?

As we see, the values of the considered features are mostly shared within a group, at least for ill people. Most of them had a number of major vessels less than or equal to 0, which means they had 0, because the values range between 0 and 3. The name of the feature already suggests why it was the most important one. For people classified as healthy, there was no such strong agreement on which feature helps the most with predictions. Certain values of 'oldpeak', 'caa' and 'cp' were most commonly chosen as the most important.

However, our dataset is not that big, so we also decided to look at a summary of the three most frequently used top features for every group.

Three Most Important Features

Here, too, a number of major vessels equal to 0 seems to be the most important factor for ill people, but it also turned out that small values of 'thall' are one of the main factors. For healthier people, 'oldpeak' and 'caa' again seem to be the most useful features during classification. What is interesting, and certainly not intuitive, is that these results say nothing about the age of the patient. So maybe it is worth training the model without this variable and seeing what happens?

Different Model

As stated before, we wanted to see what happens to the accuracy of the model if we delete the feature 'age'. To this end, we will not only look at the accuracy of the model but also compare the numbers of healthy and ill people predicted in both situations, because this lets us find out how many predictions change when 'age' is removed from training.
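
The comparison boils down to retraining the forest on the data minus one column and counting disagreeing predictions. In this sketch the 'age' column index and the synthetic data are assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=303, n_features=13, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42)

AGE = 0   # hypothetical column index of 'age'
full = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_tr, y_tr)
reduced = RandomForestClassifier(n_estimators=100, random_state=42).fit(
    np.delete(X_tr, AGE, axis=1), y_tr)

pred_full = full.predict(X_te)
pred_red = reduced.predict(np.delete(X_te, AGE, axis=1))
changed = int((pred_full != pred_red).sum())   # labels that flipped
print(f"accuracy full={full.score(X_te, y_te):.2f} "
      f"reduced={reduced.score(np.delete(X_te, AGE, axis=1), y_te):.2f} "
      f"changed labels: {changed}/{len(y_te)}")
```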

To sum up the results above: the accuracy did not change significantly, and not many labels are different: only three out of ninety-one changed. This means that the variable 'age' does not contain much information relevant to the model that the other features do not already contain. We believe that the older a person is, the statistically worse their health, and the values of the other features reflect this. That is why the results do not change much when we delete this feature.

Explainable Matrix

The concept of the Explainable Matrix for Random Forest interpretability was first introduced by Neto and Paulovich (2020). This visualization method uses a matrix representation to visualize large quantities of a Random Forest's decision rules. Using it, one can create both local and global explanations and thus either analyze an entire model or audit concrete classification results.

Global Explanation

The goal of a global explanation is to describe the whole Random Forest based on its decision rules. Figure 5 contains the global explanation of our forest. The forest contains in total 1543 rules which rely on 13 features. In the matrix, the rules are ordered by rule coverage and the features by their importance. At first glance, one can clearly see that the features resting electrocardiographic results (restecg) and fasting blood sugar (fbs) are used in only a few instances and thus have low importance. Removing these features may prove beneficial for the simplicity and overall performance of the forest. By looking closely at individual features and rules, patterns in the predicate ranges emerge. These patterns become more pronounced once the focus is on rules with higher coverage. Figure 6 provides the same view as figure 5, but only with rules whose coverage is greater than 0.15. For the most important feature, maximum heart rate achieved (thalach), the predicate ranges indicate that patients with higher values tend to be classified as being at higher risk of getting a heart attack. Higher-risk patients, on the other hand, tend to have a lower oldpeak value than low-risk patients.
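
The rule counts behind such a matrix can be recovered directly from a fitted scikit-learn forest: each root-to-leaf path is one decision rule, and the split features along the paths show how often each feature is used. This sketch (on synthetic data, with assumed hyperparameters) counts rules and per-feature split usage; it does not reproduce the Neto and Paulovich layout itself.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=303, n_features=13, random_state=42)
forest = RandomForestClassifier(n_estimators=50, max_depth=4,
                                random_state=42).fit(X, y)

# Each root-to-leaf path is one rule, so the forest's rule count is the
# total number of leaves over all trees.
n_rules = sum(est.tree_.n_leaves for est in forest.estimators_)

# How often each feature appears in a split predicate.
usage = np.zeros(X.shape[1], dtype=int)
for est in forest.estimators_:
    feats = est.tree_.feature            # value -2 marks leaf nodes
    for f in feats[feats >= 0]:
        usage[f] += 1

print(f"{n_rules} rules; least-used features:", np.argsort(usage)[:2])
```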

Local Explanation

In a hypothetical scenario, a doctor is informed that his patient, based on the current medical records, may be at a high risk of getting a heart attack. After looking at the patient’s records, the doctor is not fully convinced by the model’s decision and contacts technical support to ask for an audit of the decision. A technician creates a local explanation using the Explainable Matrix (figure 7) for this specific patient and inspects the decision rules which led to the conclusion that the patient is at a high risk. Even though the patient's values for the two most important features (thalach, oldpeak) often lie at the borders of the predicate ranges, the majority of the high-coverage rules classify the patient as being at risk with complete certainty. This decision is contradicted by some more specialized rules; however, their rule certainty is often very low. Looking at these results, the technician is confident in the model’s decision and informs the doctor.